Monet Painting GAN Project¶
Author: Grant Novota
Overview¶
Goal: to build a Generative Adversarial Network (GAN), specifically a CycleGAN model, capable of generating images in the style of Monet.
A GAN consists of at least two neural networks: a generator and a discriminator. The two models work against each other: the generator tries to produce realistic images to fool the discriminator, while the discriminator attempts to distinguish real images from generated ones.
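Formally, the two networks play a minimax game over a shared value function. This is the standard GAN objective; the losses defined later in this notebook replace the log terms with least-squares terms, as in LSGAN:

```latex
\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```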
Data¶
The dataset contains two sets of TFRecord files, with all images having three color channels (RGB):
- monet_tfrec - 300 Monet paintings, each sized 256x256 pixels, in TFRecord format
- photo_tfrec - 7028 photos, each sized 256x256 pixels, in TFRecord format
Data Source: Amy Jang, Ana Sofia Uzsoy, and Phil Culliton. I’m Something of a Painter Myself. https://kaggle.com/competitions/gan-getting-started, 2020. Kaggle.
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv
import os
import PIL
import shutil
import tensorflow as tf
from tensorflow import keras
from keras.models import Model
from keras.layers import Layer, Input, Conv2D, Conv2DTranspose, GroupNormalization
from keras.layers import Activation, ReLU, LeakyReLU, Add
from keras.optimizers import Adam
from keras.ops import pad, ones_like, zeros_like
from keras.utils import register_keras_serializable
from keras.losses import MeanSquaredError, MeanAbsoluteError
from keras.callbacks import ModelCheckpoint
Exploratory Data Analysis¶
# Get the file names from monet_tfrec and photo_tfrec
monet_file = tf.io.gfile.glob('/kaggle/input/gan-getting-started/monet_tfrec/*.tfrec')
photo_file = tf.io.gfile.glob('/kaggle/input/gan-getting-started/photo_tfrec/*.tfrec')
As described in the Overview, all images have a resolution of 256×256 pixels with 3 channels (RGB). I will define a function to decode the images from the TFRecord files and scale the pixel values to the range [-1, 1], as this is common practice when training CycleGAN models.
# Define a function to decode images from TFRecord files
def decode_img(img):
image = tf.io.decode_jpeg(img, channels = 3)
image = (tf.cast(image, tf.float32) / 127.5) - 1
image = tf.reshape(image, [256, 256, 3])
return image
# Define a function to read images from TFRecord files
def read_tfrecord(tfdata):
tfrecord_format = {
"image_name": tf.io.FixedLenFeature([], tf.string),
"image": tf.io.FixedLenFeature([], tf.string),
"target": tf.io.FixedLenFeature([], tf.string)
}
tfdata = tf.io.parse_single_example(tfdata, tfrecord_format)
image = decode_img(tfdata['image'])
return image
# Define a function to load images from the TFRecord files
def load_dataset(filenames):
# disable order, increase speed
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
# read images from multiple files if available
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.with_options(ignore_order)
dataset = dataset.map(read_tfrecord)
return dataset
# Load the datasets
monet_ds = load_dataset(monet_file).batch(1)
photo_ds = load_dataset(photo_file).batch(1)
Now, let's take a look at the images from each dataset.
# Plot images from Monet dataset
monets = iter(monet_ds)
fig, axes = plt.subplots(nrows = 1, ncols = 3, figsize = (15, 5))
fig.suptitle('Monet', fontsize = 16)
ax = axes.flatten()
for i in range(3):
monet = next(monets)
ax[i].imshow(monet[0] * 0.5 + 0.5)
ax[i].axis('off')
plt.show()
# Plot image from photo dataset
photos = iter(photo_ds)
fig, axes = plt.subplots(nrows = 1, ncols = 3, figsize = (15, 5))
fig.suptitle('Photo', fontsize = 16)
ax = axes.flatten()
for i in range(3):
photo = next(photos)
ax[i].imshow(photo[0] * 0.5 + 0.5)
ax[i].axis('off')
plt.show()
Next, check the histograms of the RGB channels of both the Monet paintings and the photos to see whether any patterns emerge.
# Define a function to load sample images from each file
def sample_imgs(file_path, num = 300):
images = []
filenames = [f for f in os.listdir(file_path) if f.endswith('.jpg')][:num]
for filename in filenames:
img_path = os.path.join(file_path, filename)
image = cv.imread(img_path)
images.append(image)
return images
# Load 300 images from each file
monet_sample = sample_imgs('/kaggle/input/gan-getting-started/monet_jpg')
photo_sample = sample_imgs('/kaggle/input/gan-getting-started/photo_jpg')
# Plot color histogram
fig,axes = plt.subplots(2, 3, figsize = (18, 8))
colors = ['blue', 'green', 'red']
for i, col in enumerate(colors):
hist_monet = cv.calcHist(monet_sample, [i], None, [256], [0, 256])
axes[0, i].plot(hist_monet, color = col)
axes[0, i].set_xlim([0, 256])
axes[0, i].set_title(f'Monet: {col}')
hist_photo = cv.calcHist(photo_sample, [i], None, [256], [0, 256])
    axes[1, i].plot(hist_photo, color = col)
axes[1, i].set_xlim([0, 256])
axes[1, i].set_title(f'Photo: {col}')
plt.show()
The color-channel distributions show no obvious pattern that distinguishes the Monet paintings from the photos.
Modeling¶
I will build a CycleGAN, a model designed for unpaired image-to-image translation. For example, it can translate an image of a horse into an image of a zebra, and vice versa.
A CycleGAN model consists of two generators and two discriminators (the detailed architectures will be described during model building), and involves three loss functions:
- Two generators:
- Generator G translates an image from X (photo) to Y (Monet);
- Generator F translates an image from Y (Monet) to X (photo);
- Two discriminators:
- Discriminator Dy identifies whether an image is a real Monet painting or not;
- Discriminator Dx identifies whether an image is a real photo or not;
- Three loss functions:
- Adversarial loss: similar to a standard GAN, the discriminators aim to maximize their accuracy in distinguishing real / fake images, while the generators try to fool the discriminators (aiming to minimize discriminators' accuracy in identifying fake images).
- Cycle consistency loss: measures the similarity between x and $\hat{x}$, and between y and $\hat{y}$. The main idea is that the twice-translated image $\hat{x}$ (x → $\hat{y}$ → $\hat{x}$) should be close to the original image x, ensuring that the two generators are consistent with each other.
- Identity loss: measures the similarity between an image and the generator's output when the image already belongs to the generator's target domain. For example, if a real Monet painting is passed through the generator that produces Monet paintings, the output should remain unchanged. This discourages the generator from modifying an image when no translation is needed.
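Written out (with least-squares adversarial terms, matching the loss functions defined later, and λ the cycle-loss weight):

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{adv}}(G, D_Y) &= \mathbb{E}_y\big[(D_Y(y)-1)^2\big] + \mathbb{E}_x\big[D_Y(G(x))^2\big] \\
\mathcal{L}_{\mathrm{cyc}}(G, F) &= \mathbb{E}_x\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_y\big[\lVert G(F(y)) - y \rVert_1\big] \\
\mathcal{L}_{\mathrm{id}}(G, F) &= \mathbb{E}_y\big[\lVert G(y) - y \rVert_1\big] + \mathbb{E}_x\big[\lVert F(x) - x \rVert_1\big]
\end{aligned}
```

Each generator's total loss combines these as $\mathcal{L}_{\mathrm{adv}} + \lambda\,\mathcal{L}_{\mathrm{cyc}} + 0.5\,\lambda\,\mathcal{L}_{\mathrm{id}}$, which is the weighting used in the `cycle_loss` and `identity_loss` functions below.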
Prepare the data for CycleGAN¶
Each training batch needs both a Monet painting and a photo, but the images live in two separate datasets, and the 300 Monet paintings are far outnumbered by the 7000+ photos. I will therefore define a function that repeats both datasets indefinitely, shuffles them, and zips them together for CycleGAN training.
# Define a function to prepare the data
def CycleGan_dataset(monet_dataset, photo_dataset):
monet_ds = monet_dataset.repeat()
photo_ds = photo_dataset.repeat()
monet_ds = monet_ds.shuffle(2048)
photo_ds = photo_ds.shuffle(2048)
gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))
return gan_ds
# Get the training set
CycleGan_ds = CycleGan_dataset(monet_ds, photo_ds)
Build CycleGAN model¶
1. Define the building blocks
# Weights initializer for the layers
kernel_init = keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
# Weights initializer for the instance normalization
gamma_init = keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
# Define a function for downsampling block
def downsample(x, filters, kernel_size, strides, activation = 'relu'):
x = Conv2D(filters, kernel_size, strides = strides, padding = 'same',
kernel_initializer = kernel_init, use_bias = False)(x)
x = GroupNormalization(groups = filters, gamma_initializer = gamma_init)(x) # Instance Normalization
x = Activation(activation)(x)
return x
# Define a function for upsampling block
def upsample(x, filters, kernel_size, strides, activation = 'relu'):
x = Conv2DTranspose(filters, kernel_size, strides = strides, padding = 'same',
kernel_initializer = kernel_init, use_bias = False)(x)
x = GroupNormalization(groups = filters, gamma_initializer = gamma_init)(x)
x = Activation(activation)(x)
return x
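The `GroupNormalization` calls above set `groups` equal to the number of channels, which makes the layer equivalent to instance normalization: each channel of each sample is normalized with its own mean and variance. A NumPy sketch of that computation, to make the equivalence concrete (the function name here is illustrative, not a Keras API):

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    """Normalize each (sample, channel) slice of an NHWC tensor independently."""
    mean = x.mean(axis=(1, 2), keepdims=True)   # per-sample, per-channel mean
    var = x.var(axis=(1, 2), keepdims=True)     # per-sample, per-channel variance
    return (x - mean) / np.sqrt(var + eps)

x = np.random.default_rng(0).normal(size=(2, 8, 8, 4))
y = instance_norm(x)
# Every (sample, channel) slice now has ~zero mean and ~unit variance.
print(np.allclose(y.mean(axis=(1, 2)), 0, atol=1e-8))  # True
```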
# Define a layer of reflection padding via subclassing
@register_keras_serializable()
class reflection_padding(Layer):
def __init__(self, padding = (1, 1), **kwargs):
self.padding = tuple(padding)
super().__init__(**kwargs)
def call(self, input_tensor):
pad_x, pad_y = self.padding
pad_width = [
[0, 0], # no padding on batch axis
[pad_y, pad_y], # padding on image height axis
[pad_x, pad_x], # padding on image width axis
[0, 0] # no padding on channel axis
]
return pad(input_tensor, pad_width, mode = "reflect")
def get_config(self):
        config = super().get_config()
config.update({"padding": self.padding})
return config
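The layer above wraps `keras.ops.pad` with `mode="reflect"`. Plain NumPy shows the same edge behavior that the layer applies to the height and width axes:

```python
import numpy as np

# A single-channel 3x3 "image"; reflect-pad by 1 pixel on each side.
img = np.arange(9).reshape(3, 3)
padded = np.pad(img, pad_width=1, mode="reflect")
print(padded.shape)         # (5, 5)
# Border rows/columns mirror the interior (edge pixels are not repeated):
print(padded[0].tolist())   # [4, 3, 4, 5, 4]
```

Reflection padding avoids the border artifacts that zero padding can introduce, which is why the CycleGAN generator uses it around its 7×7 and residual-block convolutions.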
# Define a layer of residual block via subclassing
@register_keras_serializable()
class residual_block(Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.padding = reflection_padding()
def build(self, input_shape):
filters = input_shape[-1]
self.conv1 = Conv2D(filters, kernel_size = (3, 3),
kernel_initializer = kernel_init, use_bias = False)
self.conv2 = Conv2D(filters, kernel_size = (3, 3),
kernel_initializer = kernel_init, use_bias = False)
self.instance_norm1 = GroupNormalization(groups = filters, gamma_initializer = gamma_init)
self.instance_norm2 = GroupNormalization(groups = filters, gamma_initializer = gamma_init)
def call(self, inputs):
input_tensor = inputs
# first conv block
x = self.padding(inputs)
x = self.conv1(x)
x = self.instance_norm1(x)
x = Activation('relu')(x)
# second conv block
x = self.padding(x)
x = self.conv2(x)
x = self.instance_norm2(x)
x = Activation('relu')(x)
# output
x = Add()([input_tensor, x])
return x
def get_config(self):
        return super().get_config()
2. Build the generator
The architecture of the generator is as follows:
[c7s1-64] - [d128] - [d256] - [R256]×9 - [u128] - [u64] - [c7s1-3]
- [c7s1-64]: 7 × 7 Convolution-InstanceNorm-ReLU layer with 64 filters and stride 1
- [d128]-[d256]: downsampling blocks, one with 128 filters and the other with 256 filters
- [R256]×9: 9 residual blocks, each with 256 filters
- [u128]-[u64]: upsampling blocks, one with 128 filters and the other with 64 filters
- [c7s1-3]: 7 × 7 Convolution-InstanceNorm-tanh layer with 3 filters and stride 1
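A quick sanity check on spatial sizes through this generator: the stride-2 "same" convolutions halve the height and width, the transposed stride-2 convolutions double them, and the reflection-padded 7×7 convolutions and residual blocks leave the size unchanged. Illustrative arithmetic only:

```python
size = 256                 # input H = W
sizes = [size]
for _ in range(2):         # two stride-2 downsampling blocks
    size //= 2
    sizes.append(size)
# nine residual blocks: size unchanged (reflection padding offsets the 3x3 convs)
for _ in range(2):         # two stride-2 upsampling blocks
    size *= 2
    sizes.append(size)
print(sizes)               # [256, 128, 64, 128, 256]
```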
# Build the generator
def generator(filters = 64, num_downsample = 2, num_residual = 9,
num_upsample = 2, name = None):
inputs = Input(shape = [256, 256, 3])
# First convolutional block
x = reflection_padding(padding = (3, 3))(inputs)
x = Conv2D(filters, kernel_size = (7, 7),
kernel_initializer = kernel_init, use_bias = False)(x)
x = GroupNormalization(groups = filters, gamma_initializer = gamma_init)(x)
x = Activation('relu')(x)
# Downsampling block
for _ in range(num_downsample):
filters *= 2
x = downsample(x, filters, (3, 3), (2, 2), activation = 'relu')
# Residual block
for _ in range(num_residual):
x = residual_block()(x)
# Upsampling block
for _ in range(num_upsample):
filters //= 2
x = upsample(x, filters, (3, 3), (2, 2), activation = 'relu')
# Final block
x = reflection_padding(padding = (3, 3))(x)
x = Conv2D(3, kernel_size = (7, 7),
kernel_initializer = kernel_init, use_bias = False)(x)
x = GroupNormalization(groups = 3, gamma_initializer = gamma_init)(x)
x = Activation('tanh')(x)
return Model(inputs = inputs, outputs = x)
3. Build the discriminator
The architecture of the discriminator is as follows:
[c4s2-64] - [d128] - [d256] - [d512] - [c4s1-1]
- [c4s2-64]: 4 × 4 Convolution-LeakyReLU layer with 64 filters and stride 2
- [d128]-[d256]-[d512]: downsampling blocks with 128, 256, and 512 filters, respectively
- [c4s1-1]: 4 × 4 Convolution layer with 1 filter and stride 1 (no activation; it outputs a grid of patch scores)
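With "same" padding, each stride-2 convolution halves the spatial size, so this discriminator maps a 256×256 input to a 16×16×1 grid of scores rather than a single scalar: a PatchGAN, where each output value judges the realism of one receptive-field patch. Illustrative arithmetic:

```python
size = 256
size //= 2                  # first 4x4 stride-2 conv: 256 -> 128
for _ in range(3):          # three stride-2 downsampling blocks
    size //= 2              # 128 -> 64 -> 32 -> 16
out_shape = (size, size, 1) # final stride-1 conv keeps 16x16, with 1 filter
print(out_shape)            # (16, 16, 1)
```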
# Build the discriminator
def discriminator(filters = 64, num_downsample = 3, name = None):
inputs = Input(shape = [256, 256, 3])
# First convolutional block
x = Conv2D(filters, kernel_size = (4, 4), strides = (2, 2), padding = 'same',
kernel_initializer = kernel_init, use_bias = False)(inputs)
x = Activation('leaky_relu')(x)
# Downsampling block
for _ in range(num_downsample):
filters *= 2
x = downsample(x, filters, (4, 4), (2, 2), activation = 'leaky_relu')
# Final block
x = Conv2D(1, kernel_size = (4, 4), padding = 'same',
kernel_initializer = kernel_init, use_bias = False)(x)
return Model(inputs = inputs, outputs = x)
4. Define the loss functions
# Define the functions of adversarial loss
def generator_loss(dis_fake):
loss_fn = MeanSquaredError()
return loss_fn(ones_like(dis_fake), dis_fake)
def discriminator_loss(dis_real, dis_fake):
loss_fn = MeanSquaredError()
real_loss = loss_fn(ones_like(dis_real), dis_real)
fake_loss = loss_fn(zeros_like(dis_fake), dis_fake)
return (real_loss + fake_loss) * 0.5
# Define the function of cycle consistency loss
def cycle_loss(img, cycled_img, Lambda):
loss_fn = MeanAbsoluteError()
return Lambda * loss_fn(img, cycled_img)
# Define the function of identity loss
def identity_loss(img, same_img, Lambda):
loss_fn = MeanAbsoluteError()
return 0.5 * Lambda * loss_fn(img, same_img)
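A quick numeric check of the least-squares adversarial losses above, using NumPy stand-ins for the Keras loss objects (the values are illustrative): if the discriminator scores every fake at 0.5, the generator loss is MSE(1, 0.5) = 0.25.

```python
import numpy as np

def mse(a, b):
    return float(np.mean((a - b) ** 2))

dis_fake = np.full((4, 16, 16, 1), 0.5)  # discriminator scores on generated images
dis_real = np.full((4, 16, 16, 1), 0.9)  # discriminator scores on real images

gen_loss = mse(np.ones_like(dis_fake), dis_fake)           # generator wants fakes scored as 1
dis_loss = 0.5 * (mse(np.ones_like(dis_real), dis_real)    # reals should score 1
                  + mse(np.zeros_like(dis_fake), dis_fake))  # fakes should score 0
print(gen_loss, round(dis_loss, 2))  # 0.25 0.13
```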
5. Build the CycleGAN model
# Build the CycleGAN model via subclassing
class CycleGAN(Model):
def __init__(self, generator_monet, generator_photo, discriminator_monet,
discriminator_photo, lambda_cycle = 10):
super().__init__()
self.GenM = generator_monet
self.GenP = generator_photo
self.DisM = discriminator_monet
self.DisP = discriminator_photo
self.Lamd = lambda_cycle
def compile(self,
genM_optimizer, genP_optimizer, disM_optimizer, disP_optimizer,
gen_loss, dis_loss, cycle_loss, identity_loss):
super().compile()
self.GenM_Opt = genM_optimizer
self.GenP_Opt = genP_optimizer
self.DisM_Opt = disM_optimizer
self.DisP_Opt = disP_optimizer
self.gen_loss = gen_loss
self.dis_loss = dis_loss
self.cycle_loss = cycle_loss
self.identity_loss = identity_loss
def train_step(self, batch_data):
real_monet, real_photo = batch_data
with tf.GradientTape(persistent = True) as tape:
# photo → monet → photo
fake_monet = self.GenM(real_photo, training = True)
cycled_photo = self.GenP(fake_monet, training = True)
# monet → photo → monet
fake_photo = self.GenP(real_monet, training = True)
cycled_monet = self.GenM(fake_photo, training = True)
# identity mapping
same_monet = self.GenM(real_monet, training = True)
same_photo = self.GenP(real_photo, training = True)
# discriminator output
dis_real_monet = self.DisM(real_monet, training = True)
dis_fake_monet = self.DisM(fake_monet, training = True)
dis_real_photo = self.DisP(real_photo, training = True)
dis_fake_photo = self.DisP(fake_photo, training = True)
# generator adversarial loss
genM_adver = self.gen_loss(dis_fake_monet)
genP_adver = self.gen_loss(dis_fake_photo)
# cycle loss
cycle_photo = self.cycle_loss(real_photo, cycled_photo, self.Lamd)
cycle_monet = self.cycle_loss(real_monet, cycled_monet, self.Lamd)
total_cycle = cycle_photo + cycle_monet
# identity loss
genM_identity = self.identity_loss(real_monet, same_monet, self.Lamd)
genP_identity = self.identity_loss(real_photo, same_photo, self.Lamd)
# total generator loss
genM_loss = genM_adver + total_cycle + genM_identity
genP_loss = genP_adver + total_cycle + genP_identity
# discriminator loss
disM_loss = self.dis_loss(dis_real_monet, dis_fake_monet)
disP_loss = self.dis_loss(dis_real_photo, dis_fake_photo)
# calculate gradients for generators
grads_GenM = tape.gradient(genM_loss, self.GenM.trainable_variables)
grads_GenP = tape.gradient(genP_loss, self.GenP.trainable_variables)
# calculate gradients for discriminators
grads_DisM = tape.gradient(disM_loss, self.DisM.trainable_variables)
grads_DisP = tape.gradient(disP_loss, self.DisP.trainable_variables)
# update weights of generators and discriminators
self.GenM_Opt.apply_gradients(zip(grads_GenM, self.GenM.trainable_variables))
self.GenP_Opt.apply_gradients(zip(grads_GenP, self.GenP.trainable_variables))
self.DisM_Opt.apply_gradients(zip(grads_DisM, self.DisM.trainable_variables))
self.DisP_Opt.apply_gradients(zip(grads_DisP, self.DisP.trainable_variables))
return {
"genM_loss": genM_loss,
"genP_loss": genP_loss,
"disM_loss": disM_loss,
"disP_loss": disP_loss
}
Train the CycleGAN model¶
Train the CycleGAN model on our dataset. I will add a model checkpoint that saves the weights after each epoch, in case the runtime disconnects during a long run, and then train the model for 50 epochs.
# Define model check_point
check_point = ModelCheckpoint(
filepath = '/kaggle/working/cyclegan1.weights.h5',
save_weights_only = True
)
# Train the model
CycleGan = CycleGAN(generator_monet = generator(name = 'gen_monet'),
generator_photo = generator(name = 'gen_photo'),
discriminator_monet = discriminator(name = 'dis_monet'),
discriminator_photo = discriminator(name = 'dis_photo'))
CycleGan.compile(genM_optimizer = Adam(learning_rate = 0.0002),
genP_optimizer = Adam(learning_rate = 0.0002),
disM_optimizer = Adam(learning_rate = 0.0002),
disP_optimizer = Adam(learning_rate = 0.0002),
gen_loss = generator_loss,
dis_loss = discriminator_loss,
cycle_loss = cycle_loss,
identity_loss = identity_loss)
CycleGan.fit(CycleGan_ds,
steps_per_epoch = 300,
epochs = 50,
callbacks = [check_point])
Epoch 1/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 219s 272ms/step - disM_loss: 0.0679 - disP_loss: 0.0698 - genM_loss: 11.1103 - genP_loss: 11.5427 Epoch 2/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0129 - disP_loss: 0.0150 - genM_loss: 10.5531 - genP_loss: 10.8385 Epoch 3/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0213 - disP_loss: 0.0123 - genM_loss: 9.9765 - genP_loss: 10.4574 Epoch 4/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0385 - disP_loss: 0.0287 - genM_loss: 9.3544 - genP_loss: 9.7278 Epoch 5/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0923 - disP_loss: 0.0227 - genM_loss: 8.9284 - genP_loss: 9.4696 Epoch 6/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1077 - disP_loss: 0.0417 - genM_loss: 8.7298 - genP_loss: 9.1286 Epoch 7/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 275ms/step - disM_loss: 0.1111 - disP_loss: 0.0585 - genM_loss: 8.1230 - genP_loss: 8.5678 Epoch 8/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1210 - disP_loss: 0.0486 - genM_loss: 8.1742 - genP_loss: 8.5328 Epoch 9/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1079 - disP_loss: 0.0888 - genM_loss: 7.8186 - genP_loss: 7.9829 Epoch 10/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0810 - disP_loss: 0.0886 - genM_loss: 7.5598 - genP_loss: 7.7959 Epoch 11/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1208 - disP_loss: 0.0944 - genM_loss: 7.3702 - genP_loss: 7.6318 Epoch 12/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0985 - disP_loss: 0.1173 - genM_loss: 7.2601 - genP_loss: 7.3239 Epoch 13/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1055 - disP_loss: 0.1070 - genM_loss: 7.1333 - genP_loss: 7.2801 Epoch 14/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1186 - disP_loss: 0.0842 - genM_loss: 7.2191 - genP_loss: 7.4179 Epoch 15/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1067 - disP_loss: 0.1295 - 
genM_loss: 7.1395 - genP_loss: 7.2444 Epoch 16/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.1046 - disP_loss: 0.1183 - genM_loss: 7.0189 - genP_loss: 7.1441 Epoch 17/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0896 - disP_loss: 0.1007 - genM_loss: 7.0412 - genP_loss: 7.1632 Epoch 18/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0981 - disP_loss: 0.1199 - genM_loss: 6.9059 - genP_loss: 6.9340 Epoch 19/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0792 - disP_loss: 0.1246 - genM_loss: 7.1261 - genP_loss: 7.1488 Epoch 20/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0645 - disP_loss: 0.1087 - genM_loss: 6.9861 - genP_loss: 6.8940 Epoch 21/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0847 - disP_loss: 0.0977 - genM_loss: 6.9431 - genP_loss: 7.0097 Epoch 22/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0833 - disP_loss: 0.1001 - genM_loss: 6.8667 - genP_loss: 6.8732 Epoch 23/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0814 - disP_loss: 0.0811 - genM_loss: 6.7913 - genP_loss: 6.8484 Epoch 24/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0781 - disP_loss: 0.1200 - genM_loss: 6.7949 - genP_loss: 6.8823 Epoch 25/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0895 - disP_loss: 0.1050 - genM_loss: 6.5286 - genP_loss: 6.6862 Epoch 26/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0710 - disP_loss: 0.1077 - genM_loss: 6.6754 - genP_loss: 6.5961 Epoch 27/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0648 - disP_loss: 0.1416 - genM_loss: 6.6406 - genP_loss: 6.4871 Epoch 28/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0600 - disP_loss: 0.1340 - genM_loss: 6.5313 - genP_loss: 6.4363 Epoch 29/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0690 - disP_loss: 0.1322 - genM_loss: 6.5497 - genP_loss: 6.3903 Epoch 30/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - 
disM_loss: 0.0630 - disP_loss: 0.1208 - genM_loss: 6.5231 - genP_loss: 6.4208 Epoch 31/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0563 - disP_loss: 0.0853 - genM_loss: 6.5015 - genP_loss: 6.4888 Epoch 32/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0676 - disP_loss: 0.1035 - genM_loss: 6.2169 - genP_loss: 6.1829 Epoch 33/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0625 - disP_loss: 0.1098 - genM_loss: 6.3708 - genP_loss: 6.2687 Epoch 34/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0648 - disP_loss: 0.1084 - genM_loss: 6.1040 - genP_loss: 6.0398 Epoch 35/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0535 - disP_loss: 0.1004 - genM_loss: 6.3277 - genP_loss: 6.2109 Epoch 36/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0516 - disP_loss: 0.0979 - genM_loss: 6.2374 - genP_loss: 6.1632 Epoch 37/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0540 - disP_loss: 0.0984 - genM_loss: 6.1220 - genP_loss: 6.0374 Epoch 38/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0660 - disP_loss: 0.0952 - genM_loss: 6.1781 - genP_loss: 6.1826 Epoch 39/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0509 - disP_loss: 0.0885 - genM_loss: 5.9932 - genP_loss: 5.9465 Epoch 40/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0584 - disP_loss: 0.0858 - genM_loss: 6.2598 - genP_loss: 6.2351 Epoch 41/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0452 - disP_loss: 0.0951 - genM_loss: 6.0609 - genP_loss: 5.9541 Epoch 42/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0520 - disP_loss: 0.1041 - genM_loss: 6.0167 - genP_loss: 5.8808 Epoch 43/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0587 - disP_loss: 0.0854 - genM_loss: 6.0421 - genP_loss: 6.0270 Epoch 44/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0481 - disP_loss: 0.0963 - genM_loss: 6.0180 - genP_loss: 5.9209 Epoch 45/50 
300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0540 - disP_loss: 0.0980 - genM_loss: 5.8593 - genP_loss: 5.7906 Epoch 46/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0570 - disP_loss: 0.0918 - genM_loss: 5.8051 - genP_loss: 5.8378 Epoch 47/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0512 - disP_loss: 0.0851 - genM_loss: 5.8882 - genP_loss: 5.7713 Epoch 48/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0462 - disP_loss: 0.1104 - genM_loss: 5.8982 - genP_loss: 5.6784 Epoch 49/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 274ms/step - disM_loss: 0.0442 - disP_loss: 0.0897 - genM_loss: 5.7971 - genP_loss: 5.6406 Epoch 50/50 300/300 ━━━━━━━━━━━━━━━━━━━━ 82s 273ms/step - disM_loss: 0.0467 - disP_loss: 0.1012 - genM_loss: 5.8814 - genP_loss: 5.7854
<keras.src.callbacks.history.History at 0x787f6a9c9cc0>
Results and Conclusion¶
As we can see from the training log above, the generators' loss is much larger than that of the discriminators. This is because we used different loss functions for these two types of models: while the discriminator's loss calculation is based solely on the adversarial loss, the generator's loss calculation involves three loss functions (adversarial loss, cycle consistency loss, and identity loss).
Both discriminators' losses started around 0.068, fluctuated over the 50 epochs, and ended at 0.0467 and 0.1012, respectively. Meanwhile, both generators' losses started around 11 and fell to roughly 5.8 by epoch 50, a relatively steady improvement.
# Extract generator_monet
generator_monet = CycleGan.GenM
# Visualize the Photos vs. Monet-style photos
fig, ax = plt.subplots(nrows = 2, ncols = 3, figsize = (15, 10))
fig.suptitle('Photo vs. Monet-style', fontsize = 16)
for i, img in enumerate(photo_ds.take(3)):
generated_img = generator_monet(img, training = False)[0].numpy()
generated_img = (generated_img * 127.5 + 127.5).astype(np.uint8)
img = tf.cast(img * 127.5 + 127.5, tf.uint8).numpy()
img = np.squeeze(img, axis=0)
ax[0, i].imshow(img)
ax[0, i].set_title('Photo')
ax[0, i].axis('off')
ax[1, i].imshow(generated_img)
ax[1, i].set_title('Monet-style')
ax[1, i].axis('off')
plt.show()
The Monet-style outputs already resemble paintings, even though this first model was trained for only 50 epochs.
# Create the folder to save generated images
os.makedirs('../images/', exist_ok = True)
# Generate monet-style images
i = 1
for img in photo_ds:
generated_img = generator_monet(img, training = False)[0].numpy()
generated_img = (generated_img * 127.5 + 127.5).astype(np.uint8)
im = PIL.Image.fromarray(generated_img)
im.save("../images/" + str(i) + '.jpg')
i += 1
print(f"Generated images: {len([name for name in os.listdir('../images') if os.path.isfile(os.path.join('../images', name))])}")
# archive the image folder
shutil.make_archive('/kaggle/working/images', 'zip', '../images')
Generated images: 7038
'/kaggle/working/images.zip'
Reference¶
Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. 2020. arXiv:1703.10593 [cs.CV]
Amy Jang. Monet CycleGAN Tutorial. Kaggle Notebook. https://www.kaggle.com/code/amyjang/monet-cyclegan-tutorial/notebook